{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Implementation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This section demonstrates how to fit regression and classification trees with `scikit-learn`. We will again use the {doc}`tips </content/appendix/data>` dataset for the regression tree and the {doc}`penguins </content/appendix/data>` dataset for the classification tree."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Import packages\n",
    "import numpy as np \n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Regression Tree"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's start by loading the data. This time, we'll keep the data as `pandas` dataframes rather than converting to `numpy` arrays."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Load tips data\n",
    "tips = sns.load_dataset('tips')\n",
    "X = tips.drop(columns = 'tip')\n",
    "y = tips['tip']\n",
    "\n",
    "## Train-test split\n",
    "np.random.seed(1)\n",
    "test_frac = 0.25\n",
    "test_size = int(len(y)*test_frac)\n",
    "test_idxs = np.random.choice(np.arange(len(y)), test_size, replace = False)\n",
    "X_train = X.drop(test_idxs)\n",
    "y_train = y.drop(test_idxs)\n",
    "X_test = X.loc[test_idxs]\n",
    "y_test = y.loc[test_idxs]"
   ]
  },
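  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The manual split above can also be expressed with `train_test_split` from `sklearn.model_selection`. The following is only a sketch, not the approach used in this section: it runs on a small made-up dataframe (`toy`, with hypothetical columns `x1`, `x2`, and `y`) so that it is self-contained, and it will not reproduce the `np.random.choice` split above even with the same seed.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# Toy data standing in for a real dataset (assumption: column names are made up)\n",
    "rng = np.random.default_rng(1)\n",
    "toy = pd.DataFrame({'x1': rng.normal(size = 40),\n",
    "                    'x2': rng.choice(['a', 'b'], size = 40),\n",
    "                    'y': rng.normal(size = 40)})\n",
    "X_toy = toy.drop(columns = 'y')\n",
    "y_toy = toy['y']\n",
    "\n",
    "# Hold out 25% of the rows; random_state makes the split reproducible\n",
    "X_tr, X_te, y_tr, y_te = train_test_split(X_toy, y_toy, test_size = 0.25, random_state = 1)\n",
    "print(X_tr.shape, X_te.shape)\n",
    "```\n",
    "\n",
    "Either approach keeps the original row labels, so the training and test sets can be traced back to rows of the full dataframe."
   ]
  },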
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can then fit the regression tree with the `DecisionTreeRegressor` class. Unfortunately, `scikit-learn`'s tree estimators do not currently handle categorical predictors directly, so we first have to convert these predictors to dummy variables. Note that this implies that a split on a categorical variable can only separate one value from the rest. For instance, a variable with discrete values $a$, $b$, $c$, and $d$ could not be split as $\\{a, b\\}$ versus $\\{c, d\\}$."
   ]
  },
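  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the one-versus-rest restriction concrete, the sketch below dummy-encodes a made-up variable with levels `a`, `b`, `c`, and `d`. The dataframe `toy` and its column `v` are hypothetical and exist only for illustration. Each resulting column flags a single level, so a threshold split on one column can only peel that level off from the others.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "# Hypothetical categorical variable with four levels\n",
    "toy = pd.DataFrame({'v': ['a', 'b', 'c', 'd', 'a', 'c']})\n",
    "\n",
    "# drop_first = True drops the first level ('a'), leaving one indicator column per remaining level\n",
    "dummies = pd.get_dummies(toy, drop_first = True)\n",
    "print(dummies.columns.tolist())\n",
    "\n",
    "# A split on a single indicator column isolates exactly one level\n",
    "left = toy.loc[dummies['v_b'] == 0, 'v'].unique()   # every level except 'b'\n",
    "right = toy.loc[dummies['v_b'] == 1, 'v'].unique()  # only 'b'\n",
    "print(left, right)\n",
    "```\n",
    "\n",
    "Grouping $\\{a, b\\}$ against $\\{c, d\\}$ would take more than one split on these columns, which is the limitation described above."
   ]
  },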
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAbYAAAFTCAYAAABVgClBAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAd3ElEQVR4nO3df5TddX3n8dfrzmTCMAnLGCapMKDVo7EsRYGri43rskJdqi6WI7Xa0lR3l5TFBey6uuqx626tZ6u0VXctYqJVKIo/ojnb5bQc/MW2zQrtBCgCIUtBhAAmQ0wgCWMmM/PeP+6dOAlz59e99/v53s99Ps6Zk5k7934/7+93knnl+7nv7/fjiBAAALmopC4AAIBWItgAAFkh2AAAWSHYAABZIdgAAFkh2AAAWSHYAABZIdgAJGf7DbbfkLoO5MFcoA0gJdsnSbq1/uUvR8SelPWg8xFsAJKy/aeStkjqkXRRRLwrcUnocAQb0OFsPyLp30XEt1PXApQB77Gh49k+MONjyvbYjK9/c4nbfMT2BfM85zW2/6/tp23/xPZW269c2l4AaJXe1AUAzYqIFdOfF3X2YvsESTdL+veSviapT9I/l3SoneMCmB9nbMie7ZNtf8P2qO0f2r5qxvf+s+3Hbe+3vcP2+bb/XNJpkv53/azvfbNs9qWSFBE3RcRkRIxFxK0Rcc+Mbb/f9kP1bd9v++IZ33vE9ntt32P7oO3P215j+6/qz/+27cFjnv+B+nb22v6C7eMWu7/HPG+F7Unbz5/x2Bm2n7S98pjnvt/25mMe+5Tt/9HoOM76w1ji+MCiRAQffGTzIekRSRfM+LoiaZuk/6LaWdWLJD0s6V9JWivpMUkn15/7Qkkvnm07s4xzgqQ9kq6X9CuSBmd5zq9JOrlew69LOijp+TO2f7ukNZJOkbRb0p2SzpK0XNJ3JX34mP26V9Kpkp4naaukPzi21rn2t8F+3CfpjTO+vlnSlbM87wWSnpV0Qv3rHklPSjp3ruO4gJ/Xgsbng4/FfHDGhty9UtJQRPx+RIxHxMOSNkl6m6RJ1ULkdNvLIuKRiHhoIRuNiGckvUZS1Lc3avsvbK+Z8ZyvR8QTETEVEV+V9KCkV83YzP+MiF0R8bikv5F0R0TcFRGHVOsSPOuYYT8dEY9FxE8kfVTS2xe5v7P5e0lnS5Lt10o6XdJnZ9nfH6kWvL9af+h1kp6NiNvVxHFc6PjAYhBsyN0LJJ1se9/0h6QPSloTEf8o6d2S/quk3ba/YvvkhW44IrZHxDsiYljSGaqdnX1y+vu219u+e8a4Z0g6acYmds34fGyWr1foaI/N+PxH9fEWvL8NduNIsEj6uKTfi4jxBs/9sn4Wpr9R/1pNHsfFjA8sCMGG3D0m6YcRceKMj5UR8QZJiogvR8RrVAuEkPSx+usWdR1MRDwg6YuqhZdsv0C1M6X/IGlVRJyo2lSim9iXU2d8fpqkJ2Z5zpz7O4u/l3S27bdI6pd00xzjf13SebaHJV2serBJcx7H+SxmfGBBCDbk7u8kPVNvbui33VNvUHil7bW2X2d7uaSfqnaWNFl/3S7V3p+ale2X2X5P/Ze8bJ+q2tnM7fWnDKj2C360/v13qh56TXiX7WHbz1PtLOyri9nfBtv8B0k/J+mPJb0/IqYaDR4Ro5Juk/QF1cJzuyTNcxzns+DxgYUi2JC1iJiU9K8lvULSDyU9Jelzkv6Jau8L/WH9sR9LWq1aYEjSf5f0ofp03n+aZdP7Jf0zSXfYPqhaoN0r6T31ce9X7Zf191ULyV9UreGjGV9W7dZTD9c//mCR+/sc9ffzfiDpkYj4qwXWcIFmnK1pjuNY7/L84LEbaWJ8YF7ceQToAO26Ps92n6R/lPTWeiNIoVKPjzxxxgZ0tw9L2powVFKPjwwRbEAXsn227aclvVbSld02PvLGVCQAICucsQEAskKwAQCy0hF3
97/wwgvjlltuSV0GAKA8Gt7soCPO2J566qnUJQAAOkRHBBsAAAtFsAEAspIk2Gz/ru37bN9r+6ZGCyYCALBYhQeb7VMkXSWpGhFnqLZgYaO1ogAAWJRUU5G9kvpt90o6XrMvvwEAwKIVHmz11YL/SNKjqi0t/3RE3Fp0HQCAPKWYihyU9GZJP6/aCsADti+d5XkbbI/YHhkdHS26TABAh0oxFXmBaosUjkbEYUnflPRLxz4pIjZGRDUiqkNDQ4UXCQDoTCnuPPKopHNtH6/aSrvnSxpJUAcAoCBTU6E9B8c1PjGpvt4erRroU6XS8OYhTSk82CLiDtubJd0paULSXZI2Fl0HAKAYU1OhHbv267IbRrRz75iGB/u1aX1Va9esbEu4dcSyNdVqNUZGOKkDgE40uv+QLr52q3buHTvy2PBgv7ZcsU5DK5cvdbOdfa9IAEDnGp+YPCrUJGnn3jGNT0y2ZTyCDQDQVn29PRoe7D/qseHBfvX19rRlPIINANBWqwb6tGl99Ui4Tb/Htmqgry3jdcR6bACAzlWpWGvXrNSWK9bl2RUJAOg+lYqbaRRZ3FiFjAIAQEEINgBAVgg2AEBWCDYAQFYINgBAVgg2AEBWCDYAQFYINgBAVgg2AEBWCDYAQFYINgBAVgg2AEBWCDYAQFYINgBAVgg2AEBWCDYAQFYINgBAVgg2AEBWCDYAQFYINgBAVgg2AEBWCDYAQFYINgBAVgoPNttrbd894+MZ2+8uug4AQJ56ix4wInZIeoUk2e6R9LikLUXXAQDIU+qpyPMlPRQRP0pcBwAgE6mD7W2SbkpcAwAgI8mCzXafpIskfb3B9zfYHrE9Mjo6WmxxAICOlfKM7Vck3RkRu2b7ZkRsjIhqRFSHhoYKLg0A0KlSBtvbxTQkAKDFkgSb7eMl/bKkb6YYHwCQr8Lb/SUpIp6VtCrF2ACAvKXuigQAoKWSnLEBnWRqKrTn4LjGJybV19ujVQN9qlTc8tcAaA2CDZjD1FRox679uuyGEe3cO6bhwX5tWl/V2jUrGwbVUl4DoHWYigTmsOfg+JGAkqSde8d02Q0j2nNwvKWvAdA6BBswh/GJySMBNW3n3jGNT0y29DUAWodgA+bQ19uj4cH+ox4bHuxXX29PS18DoHUINmAOqwb6tGl99UhQTb9ftmqgr6WvAdA6jojUNcyrWq3GyMhI6jLQpeiKBEqp4T8ouiKBeVQq1tDK5W1/DYDWYCoSAJAVgg0AkBWCDQCQFYINAJAVgg0AkBWCDQCQFYINAJAVgg0AkBWCDQCQFYINAJAVgg0AkBWCDQCQFYINAJAVgg0AkBWCDQCQFYINAJAVgg0AkBWCDQCQFYINAJCVJMFm+0Tbm20/YHu77VenqAMAkJ/eRON+StItEXGJ7T5JxyeqAwCQmcKDzfYJkl4r6R2SFBHjksaLrgMAkKcUU5EvkjQq6Qu277L9OdsDxz7J9gbbI7ZHRkdHi68SANCRUgRbr6SzJX0mIs6SdFDS+499UkRsjIhqRFSHhoaKrhEA0KFSBNtOSTsj4o7615tVCzoAAJpWeLBFxI8lPWZ7bf2h8yXdX3QdAIA8peqKvFLSl+odkQ9LemeiOgAAmUkSbBFxt6RqirEBAHnjziMAgKwQbACArBBsAICsEGwAgKwQbACArBBsAICsEGwAgKwQbACArBBsAICsEGwAgKykulckUDpTU6E9B8c1PjGpvt4erRroU6Xipre1rLei3oo1Nl7b7mD/Mu0dO9zUOK2sFcgNwQaoFhQ7du3XZTeMaOfeMQ0P9mvT+qrWrlm5pNA5dlvXXHKmPn7LDg2t7NNV579Ul9+4bcnjtLJWIEeOiNQ1zKtarcbIyEjqMpCx0f2HdPG1W7Vz79iRx4YH+7XlinUaWrm8Jdv6vTedLkn6yM33NzVOK2sFOljD/8VxxgZIGp+YPCooJGnn3jGNT0y2bFsn9i878nkz47SyViBHNI8Akvp6ezQ82H/UY8OD
/err7WnZtvaNHda+scNNj9PKWoEcEWyApFUDfdq0vnokMKbft1o10NeSbV1zyZm67raH9I1tj+m6S89papxW1grkiPfYgDq6IoGOwntswHwqFbes+WLWbQ387NNmx2llrUBumIoEAGSFYAMAZIVgAwBkhWADAGSFYAMAZIVgAwBkhWADAGSFYAMAZIVgAwBkJcmdR2w/Imm/pElJExFRTVEHACA/KW+p9S8j4qmE46Ng3N8QQBG4VyQKwarPAIqS6j22kHSr7W22NySqAQXac3D8SKhJtYUxL7thRHsOjieuDEBuUp2xrYuIJ2yvlvQt2w9ExF/PfEI98DZI0mmnnZaiRrQQqz4DKEqSM7aIeKL+525JWyS9apbnbIyIakRUh4aGii4RLcaqzwCKUniw2R6wvXL6c0mvl3Rv0XWgWKz6DKAoKaYi10jaYnt6/C9HxC0J6kCBKhVr7ZqV2nLFOroiAbRV4cEWEQ9LennR4yI9Vn0GUATuPAIAyArBBgDICsEGAMgKdx5JLPfbTJVl/4quoyz7DXQjgi2h3G8zVZb9K7qOsuw30K2Yikwo99tMlWX/iq6jLPsNdCuCLaHcbzNVlv0ruo6y7DfQrQi2hHK/zVRZ9q/oOsqy30C3ItgSyv02U2XZv6LrKMt+A93KEZG6hnlVq9UYGRlJXUZb5N49V5b9oysSyE7Df1B0RSaW+22myrJ/RddRlv0GutGipiJtf9s293kEAJTWnMFm+3TbN8546H2SPmH7C7af397SAABYvPnO2L4j6UPTX0TEnRHxOkk3S7rF9odt9zd8NQAABZsv2F4v6aMzH3BtIbUdkj4j6UpJD9r+rfaUBwDA4swZbBHxg4j4zemvbf+tpMclfULSKZLeIek8Sa+yvbF9ZQIAsDCL7Yq8XNJ98dxrBK60vb1FNQEAsGSLCraIuHeOb7+xyVoAAGhay+48EhEPt2pbAAAsFbfUAgBkhWADAGSFYAMAZIVgAwBkhWADAGSFYAMAZIVgAwBkhfXYMCcWzATQaQg2NDQ1Fdqxa78uu2FEO/eOaXiwX5vWV7V2zUrCDUBpJZuKtN1j+y7bN6eqAXPbc3D8SKhJ0s69Y7rshhHtOTieuDIAaCzlGdvVkrZLOiFhDZjD+MTkkVCbtnPvmMYnJpe0PaY188XPFmWSJNhsD6t20+SPSvqPKWrA/Pp6ezQ82H9UuA0P9quvt2fR22JaM1/8bFE2qaYiPynpfZKmEo2PBVg10KdN66saHqwtkj79C2vVQN+it8W0Zr742aJsCj9js/0mSbsjYpvt8+Z43gZJGyTptNNOK6g6zFSpWGvXrNSWK9Y1PcXU6mlNlAc/W5RNijO2dZIusv2IpK9Iep3tG499UkRsjIhqRFSHhoaKrhF1lYo1tHK5Thk8XkMrly95aml6WnOmpU5rolz42aJsCg+2iPhARAxHxAslvU3SdyPi0qLrQLFaOa2JcuFni7LhOjYUopXTmigXfrYom6TBFhG3SbotZQ0ozvS0JvLDzxZlwr0iAQBZYSoSc8rpwtuc9gVAYwQbGsrpwtuc9gXA3JiKREM5XXib074AmBvBhoZyuvA2p30BMDeCDQ3ldOFtTvsCYG4EGxrK6cLbnPYFwNwcEalrmFe1Wo2RkZHUZXSlnDoJc9oXAGr4j5euSMwppwtvc9oXAI0xFQkAyArBBgDICsEGAMgKwQYAyArNI5mh8w9AtyPYMsL9EAGAqciscD9EACDYssL9EAGAYMsK90MEAIItK9wPEQBoHkmulV2MlYq1ds1KbbliXVd1RXZzJ2g37zvQCMGWUDu6GLvtfojd3AnazfsOzIWpyIToYmxeNx/Dbt53YC4EW0J0MTavm49hN+87MBeCrUlTU6HR/Yf0+N5nNbr/kKamFr6+XTu6GJuppxMt5hjmdmzoggVmR7A1Yfo9jouv3ap1H/ueLr52q3bs2r/gX5it7mJstp5OtNBjmOOxoQsWmB0raDdhdP8hXXzt
1qOmg4YH+7XlinULbuBoZVdbK+rpRAs5hrkeG7oi0cVYQbsdWvEeRyu7GLv1PZeFHMNcj023dcECC8FUZBPK9h5H2eopE44N0D0KDzbbx9n+O9v/YPs+2/+t6BpapWzvcZStnjLh2ADdo/D32Gxb0kBEHLC9TNLfSro6Im5v9JqyvscmNf8eR6vfIyn79lLKaV8AlOg9tqgl6YH6l8vqH+XvYGmgmfc4yn7nkdzubMH7UUB3SPIem+0e23dL2i3pWxFxR4o6Uiv7nSPKXh8AzCZJsEXEZES8QtKwpFfZPuPY59jeYHvE9sjo6GjxRRag7J16Za8PAGaTtCsyIvZJuk3ShbN8b2NEVCOiOjQ0VHhtRSh7p17Z6wOA2aToihyyfWL9835JF0h6oOg6yqDsnXplrw8AZpOiK/JMSddL6lEtWL8WEb8/12ua7Yqcqxsudafc4cOT2n3gkCamQr0Va/WK5Vq2bPFnRO3Yj6mp0OP7ntWhiVDF0lRIfb3W8p6KKpVKw+PY39ejianQ4Ympo2pJfayL1E37CiRSqq7IeySdVdR4c3X2SUra9TcxMaUduw/o8hu3HRn/ukvP0cvWrFRv78JPptvVvbhvbFw7947pvZvvObLday45U5L03s33zHoch1Ys1/suXHvUazatr+olQyv04OiBbDos55JbNynQabK/88hcnX2pu/52Hzh0JNSmx7/8xm3afeDQorbTrv0YG588ElDT233v5nv0cycc1/A4Xn7ei5/zmstuGNHuA4e6psMy9d8roNtlf6/I+Tr7Unb9HZ6cmnX8icmpRW2nXd2LkxGzbneyPn0923E8sX9Zw33qlg5LukmBtLI/Y5ursy9119+ynsqs4/f2LO7H0q79OG7Z7Nv98dM/PWqMmePvGzvccJ+6pcMy9d8roNtlH2xzdfal7vpbvWK5rrv0nKPGv+7Sc7R6xeLujtGu/ThpYPlztvuJt75cf3zr/2t4HK+77SFdc8mZz6ll9YrnbuuGf/MqhSKbhT+npf57BXS7rliPrcxdkRMTU7WuyMkp9fZUtHrF8kU1jkxr137M3O6y3op6K9bY+NzHcSFdkf19Pdr1zKFsGyxS/70CukDDf1BdEWwon1wX/gRQmIbBlv1UJMqJBgsA7UKwIQkaLAC0C8GGJGiwANAu2V/H1mm6pemgUrHWrlmpLVesy35fARSLYCuRbrsVEwt/AmgHpiJLhFsxAUDzCLYSoVMQAJpHsJUInYIA0DyCrUToFASA5tE8UiJ0CgJA8wi2kqFTEACaw1QkACArBBsAICsEGwAgKwQbACArBBsAICsEGwAgKwQbACArXMfWZvMtQ9Mty9QAQFEItjaabxmablumBgCKwFRkG823DA3L1ABA63HG1kbzLUPTjmVqmNoE0O0KP2Ozfart79nebvs+21cXXUNR5luGptXL1ExPbV587Vat+9j3dPG1W7Vj135NTcXSdgAAOlCKqcgJSe+JiF+QdK6kd9k+PUEdizI1FRrdf0iP731Wu5/5qXY/81P9aM9BPbFvTBMTU7O+Zr5laBayTM3McUf3H5ozpFo1tbmYMZtV5FgAukPhU5ER8aSkJ+uf77e9XdIpku4vupaFmq3J45pLztTHb9mh0QOHdN2l5+hla1aqt/e5/09Yc8JyfXXDuZoM6bhlFZ00sPzI1GClYr1kaIW+9juv1uHJKS3rqWj1ip99f7HNJa2Y2iyyoYXmGQDtkLR5xPYLJZ0l6Y6UdcxntjOh926+R5ef92Lt3Dumy2/cpt0HDh31mulf2hd9ujYt+BubbteeA+PPec6Dowf01s9+X//imtv01s9+Xw+OHjhy1rLYM7BWTG0W2dBC8wyAdkgWbLZXSPqGpHdHxDOzfH+D7RHbI6Ojo8UXOEOjM6ET+5cd+Xxi8ujpyIX80p7vOYs9A2vFCtztaGgpw1gAukeSrkjby1QLtS9FxDdne05EbJS0UZKq1WrSN16mz4Rm/hIeHuzXvrHD
Rz7v7Tn6/wgL+aU933MajdvoDKwVK3AvdsxmFDkWgO6RoivSkj4vaXtE/EnR4y/FbGdC11xypq677SEND/brukvP0eoVR696vZBpwfmes5QzsOkVuE8ZPF5DK5cv+r2qVpz1lXEsAN3DEcWeDNl+jaS/kfQDSdPzdx+MiL9s9JpqtRojIyNFlNfQzOvDltXPzn56eFK99YaPYxtHFtIYsdDnFH1dWpFjct0dgCVq+Iui8GBbijIE21Is5Jc2v9gBYEka/qLkziNtND0t2OxzAAALx70iAQBZIdgAAFkh2AAAWSHYAABZIdgAAFkh2AAAWaHdv00Wen0a17EBQGsRbC00HVJTU1N66uC4fufPt825HEsRy7YQnAC6DVORLTJz9eq7dz59JNSkxsuxtHvZFlbUBtCNCLYWmRlSJ/YvW9ByLO1etoX1zgB0I6YiW2RmSO0bO7yg5Vj6env0+tNX6y3nnKoT+5dp39hhfWPbYy1btoX1zgB0I87YWmTmEjTX3faQPvaWM+ddjmWwf5muOv+l+sjN9+vXN96uj9x8v646/6UarC9g2sqaprHeGYDccXf/Fjm2EeT1p6/Wh954unoqbti0Mbr/kC6+dutzzuy2XLFuzhsjL6bjstnmFJpPAJQUd/dvt6WsXr2UqcLFhFWzK2oX0bUJAK3GVGQLLXb16qVMFS62IaSZFbVpPgHQiQi2hFYN9GnT+uq878XNVGRDCM0nADoRU5EJLWWqcPosb76Oy1YociwAaBXO2BJb7FThUs7ylqrIsQCgVeiK7EBFdirSFQmgpOiKzMn0WV5uYwFAKzAVCQDISlecsbViOo0pOQDoDNkHW6vuvsGFygDQGbKfimzFRcZcqAwAnSP7YGvFRcZcqAwAnSP7YGvFHe65Sz4AdI7sg60VFxlzoTIAdI4kF2jb/jNJb5K0OyLOmO/5zV6gTVckAGSndBdof1HSpyXdUMRgrbjImAuVAaAzJJmKjIi/lvSTFGMDAPJW2vfYbG+wPWJ7ZHR0NHU5AIAOUdpgi4iNEVGNiOrQ0FDqcgAAHaK0wQYAwFIQbACArCQJNts3Sfq+pLW2d9r+tynqAADkJ0m7f0S8PcW4AID8MRUJAMhKkjuPLJbtUUk/Sl1Hm50k6anURXQojt3SceyWjmO3NK06bk9FxIWzfaMjgq0b2B6JiGrqOjoRx27pOHZLx7FbmiKOG1ORAICsEGwAgKwQbOWxMXUBHYxjt3Qcu6Xj2C1N248b77EBALLCGRsAICsEW2K2T7X9Pdvbbd9n++rUNXUS2z2277J9c+paOontE21vtv1A/e/eq1PX1Cls/2793+q9tm+yfVzqmsrK9p/Z3m373hmPPc/2t2w/WP9zsNXjEmzpTUh6T0T8gqRzJb3L9umJa+okV0vanrqIDvQpSbdExMskvVwcwwWxfYqkqyRVI+IMST2S3pa2qlL7oqRjrzV7v6TvRMRLJH2n/nVLEWyJRcSTEXFn/fP9qv2COSVtVZ3B9rCkN0r6XOpaOontEyS9VtLnJSkixiNiX9qqOkqvpH7bvZKOl/RE4npKq8Gi0m+WdH398+sl/WqrxyXYSsT2CyWdJemOtJV0jE9Kep+kqdSFdJgXSRqV9IX6NO7nbA+kLqoTRMTjkv5I0qOSnpT0dETcmraqjrMmIp6Uav+xl7S61QMQbCVhe4Wkb0h6d0Q8k7qesrP9Jkm7I2Jb6lo6UK+ksyV9JiLOknRQbZgOylH9/aA3S/p5SSdLGrB9adqqcCyCrQRsL1Mt1L4UEd9MXU+HWCfpItuPSPqKpNfZvjFtSR1jp6SdETE9M7BZtaDD/C6Q9MOIGI2Iw5K+KemXEtfUaXbZfr4k1f/c3eoBCLbEbFu19zq2R8SfpK6nU0TEByJiOCJeqNqb99+NCP7nvAAR8WNJj9leW3/ofEn3Jyypkzwq6Vzbx9f/7Z4vGm8W6y8k/Xb989+W9L9aPUCS9dhwlHWSfkvSD2zfXX/sgxHx
lwlrQv6ulPQl232SHpb0zsT1dISIuMP2Zkl3qtbRfJe4A0lD9UWlz5N0ku2dkj4s6Q8lfa2+wPSjkn6t5eNy5xEAQE6YigQAZIVgAwBkhWADAGSFYAMAZIVgAwBkhWADAGSFYAMAZIVgA0rO9i/a3jrj67NtfzdlTUCZcYE2UHK2K6otjXJKREza/p5qa/jdmbg0oJS4pRZQchExZfs+Sf/U9kskPUqoAY0RbEBnuF21+4peoeeuSAxgBoIN6Ay3S/qipD+tL3YJoAHeYwM6QH0K8v9IeklEHExdD1BmdEUCneFqSR8g1ID5EWxAidl+se0HJPVHxPWp6wE6AVORAICscMYGAMgKwQYAyArBBgDICsEGAMgKwQYAyArBBgDICsEGAMgKwQYAyMr/ByfjNVgylwIuAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 504x360 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "from sklearn.tree import DecisionTreeRegressor\n",
    "\n",
    "## Get dummies (align test columns to train in case a level is missing from one split)\n",
    "X_train = pd.get_dummies(X_train, drop_first = True)\n",
    "X_test = pd.get_dummies(X_test, drop_first = True)\n",
    "X_test = X_test.reindex(columns = X_train.columns, fill_value = 0)\n",
    "\n",
    "## Build model\n",
    "dtr = DecisionTreeRegressor(max_depth = 7, min_samples_split = 5)\n",
    "dtr.fit(X_train, y_train)\n",
    "y_test_hat = dtr.predict(X_test)\n",
    "\n",
    "## Visualize predictions\n",
    "fig, ax = plt.subplots(figsize = (7, 5))\n",
    "sns.scatterplot(x = y_test, y = y_test_hat, ax = ax)\n",
    "ax.set(xlabel = r'$y$', ylabel = r'$\\hat{y}$', title = r'Test Sample $y$ vs. $\\hat{y}$')\n",
    "sns.despine()"
   ]
  },
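  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A scatterplot gives a visual check, but a numeric summary is often useful too. The sketch below computes the test mean squared error and $R^2$ with `sklearn.metrics`. It is an illustration under stated assumptions: the data are synthetic (generated here, not the tips data) and `random_state = 0` is added only to make the run repeatable.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from sklearn.tree import DecisionTreeRegressor\n",
    "from sklearn.metrics import mean_squared_error, r2_score\n",
    "\n",
    "# Synthetic data (assumption: stands in for a real dataset)\n",
    "rng = np.random.default_rng(1)\n",
    "X_toy = rng.uniform(0, 10, size = (200, 2))\n",
    "y_toy = np.sin(X_toy[:, 0]) + 0.1 * X_toy[:, 1] + rng.normal(scale = 0.1, size = 200)\n",
    "X_tr, X_te = X_toy[:150], X_toy[150:]\n",
    "y_tr, y_te = y_toy[:150], y_toy[150:]\n",
    "\n",
    "# Same depth and split settings as the regression tree above\n",
    "tree = DecisionTreeRegressor(max_depth = 7, min_samples_split = 5, random_state = 0)\n",
    "tree.fit(X_tr, y_tr)\n",
    "y_te_hat = tree.predict(X_te)\n",
    "\n",
    "# Mean squared error (lower is better) and R^2 (closer to 1 is better)\n",
    "mse = mean_squared_error(y_te, y_te_hat)\n",
    "r2 = r2_score(y_te, y_te_hat)\n",
    "print(f'Test MSE: {mse:.3f}, test R^2: {r2:.3f}')\n",
    "```\n",
    "\n",
    "The same two calls could be applied to `y_test` and `y_test_hat` from the cell above."
   ]
  },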
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Classification Tree"
   ]
  },
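  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The classification tree implementation in `scikit-learn` is nearly identical: we swap `DecisionTreeRegressor` for `DecisionTreeClassifier` and summarize performance with test accuracy rather than a scatterplot. The corresponding code is provided below."
   ]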
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Load penguins data\n",
    "penguins = sns.load_dataset('penguins')\n",
    "penguins = penguins.dropna().reset_index(drop = True)\n",
    "X = penguins.drop(columns = 'species')\n",
    "y = penguins['species']\n",
    "\n",
    "## Train-test split\n",
    "np.random.seed(1)\n",
    "test_frac = 0.25\n",
    "test_size = int(len(y)*test_frac)\n",
    "test_idxs = np.random.choice(np.arange(len(y)), test_size, replace = False)\n",
    "X_train = X.drop(test_idxs)\n",
    "y_train = y.drop(test_idxs)\n",
    "X_test = X.loc[test_idxs]\n",
    "y_test = y.loc[test_idxs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9036144578313253"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
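   ],
   "source": [
    "from sklearn.tree import DecisionTreeClassifier\n",
    "\n",
    "## Get dummies (align test columns to train in case a level is missing from one split)\n",
    "X_train = pd.get_dummies(X_train, drop_first = True)\n",
    "X_test = pd.get_dummies(X_test, drop_first = True)\n",
    "X_test = X_test.reindex(columns = X_train.columns, fill_value = 0)\n",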
    "\n",
    "## Build model\n",
    "dtc = DecisionTreeClassifier(max_depth = 10, min_samples_split = 10)\n",
    "dtc.fit(X_train, y_train)\n",
    "y_test_hat = dtc.predict(X_test)\n",
    "\n",
    "## Observe Accuracy\n",
    "np.mean(y_test_hat == y_test)"
   ]
  }
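  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Beyond the accuracy figure, it can help to look at the splits a fitted tree actually uses. The sketch below is an illustration rather than part of the analysis above: it uses the iris data bundled with `scikit-learn` as a stand-in for the penguins data, splits it with `train_test_split`, and fixes `random_state` so the run is repeatable. It computes accuracy with `accuracy_score` and prints the tree with `export_text`.\n",
    "\n",
    "```python\n",
    "from sklearn.datasets import load_iris\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.tree import DecisionTreeClassifier, export_text\n",
    "\n",
    "# Iris ships with scikit-learn; it stands in for the penguins data here (assumption)\n",
    "iris = load_iris()\n",
    "X_tr, X_te, y_tr, y_te = train_test_split(iris.data, iris.target, test_size = 0.25, random_state = 1)\n",
    "\n",
    "# Same depth and split settings as the classification tree above\n",
    "clf = DecisionTreeClassifier(max_depth = 10, min_samples_split = 10, random_state = 0)\n",
    "clf.fit(X_tr, y_tr)\n",
    "y_te_hat = clf.predict(X_te)\n",
    "\n",
    "# accuracy_score is the fraction of correct predictions, like np.mean(y_test_hat == y_test)\n",
    "acc = accuracy_score(y_te, y_te_hat)\n",
    "print(f'Test accuracy: {acc:.3f}')\n",
    "\n",
    "# Print the fitted splits as indented text\n",
    "print(export_text(clf, feature_names = list(iris.feature_names)))\n",
    "```\n",
    "\n",
    "With dummy-encoded predictors, the printed rules show each categorical split as a threshold on a single indicator column."
   ]
  }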
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
